---
title: "Study shrinkage and DART in xgboost modeling using a simple dataset"
author: "Yang Liu"
date: "2018-11-15"
description: "A small R experiment that illustrates how shrinkage and DART dropout change the behavior of XGBoost trees on a simple toy dataset."
categories:
- "Machine Learning"
tags:
- "XGBoost"
- "DART"
page-layout: article
execute:
freeze: true
eval: false
resources:
- "source.Rmd"
- "index_files/**"
- "images/**"
- "temp/**"
- "2018-11-15-xgboost-dart-test-using-a-simple-tree_files/**"
- "*.png"
- "*.jpg"
- "*.jpeg"
- "*.JPG"
- "*.PNG"
- "*.gif"
- "*.svg"
- "*.rds"
- "*.csv"
- "*.xlsx"
---
<script src="/rmarkdown-libs/htmlwidgets/htmlwidgets.js"></script>
<script src="/rmarkdown-libs/viz/viz.js"></script>
<link href="/rmarkdown-libs/DiagrammeR-styles/styles.css" rel="stylesheet" />
<script src="/rmarkdown-libs/grViz-binding/grViz.js"></script>
<div id="TOC">
<ul>
<li><a href="#data">Data</a></li>
<li><a href="#shrinkage">Shrinkage</a></li>
<li><a href="#dart-dropout---mart">DART: Dropout - MART</a><ul>
<li><a href="#skip_drop"><code>skip_drop</code></a></li>
<li><a href="#rate_drop"><code>rate_drop</code></a></li>
<li><a href="#one_drop"><code>one_drop</code></a></li>
</ul></li>
</ul>
</div>
<p>A small dataset is often the best way to understand a modeling algorithm. Inspired by <a href="http://arfer.net/w/xgboost-sparsity">my colleague Kodi’s excellent work showing how <code>xgboost</code> handles missing values</a>, I used a simple 5x2 dataset to show how shrinkage and DART influence tree growth in an XGBoost model.</p>
<div id="data" class="section level1">
<h1>Data</h1>
<pre class="r"><code>set.seed(123)
n0 <- 5
X <- data.frame(x1 = runif(n0), x2 = runif(n0))
Y <- c(1, 5, 20, 50, 100)
cbind(X, Y)</code></pre>
<pre><code>## x1 x2 Y
## 1 0.2876 0.04556 1
## 2 0.7883 0.52811 5
## 3 0.4090 0.89242 20
## 4 0.8830 0.55144 50
## 5 0.9405 0.45661 100</code></pre>
</div>
<div id="shrinkage" class="section level1">
<h1>Shrinkage</h1>
<ol style="list-style-type: decimal">
<li>Step-size shrinkage is a major tool for preventing overfitting, or over-specialization.<br />
</li>
</ol>
<ul>
<li><a href="https://xgboost.readthedocs.io/en/latest/parameter.html">The R documentation</a> gives the range of the learning rate <code>eta</code> as [0, 1], but <code>xgboost</code> accepts any value <span class="math inline">\(eta\ge0\)</span>. Here I choose <code>eta = 2</code>, and the model fits the training data perfectly in two steps: the train RMSE is 0 from iteration 2 onward, and only two trees do any work (the third tree is a single leaf with value 0).<br />
Of course, using a very large <em>eta</em> in real applications is a bad idea: the trees will not be helpful and the predicted values can be very wrong.<br />
</li>
<li><code>max_depth</code> is the maximum depth of a tree. I set it to 10, but that depth is never reached.<br />
</li>
<li>Apart from the default L2 penalty on leaf weights (<code>lambda = 1</code>), no other regularization is applied.</li>
<li>By setting <code>base_score = 0</code> in <code>xgboost</code> we can add up the values in the leaves of two trees to get every number in Y: 1, 5, 20, 50, 100</li>
</ul>
<pre class="r"><code>library(xgboost)

# plain gradient-boosted trees with a deliberately large learning rate
param_gbtree <- list(objective = 'reg:linear', # called 'reg:squarederror' in later xgboost releases
                     nrounds = 3,
                     eta = 2,
                     max_depth = 10,
                     min_child_weight = 0,
                     booster = 'gbtree'
)
# Fit the model on the toy data, then print the evaluation log, the predictions
# and the SHAP contributions, and plot every tree.
simple.xgb.output <- function(param, ...){
  set.seed(1234)
  m <- xgboost(data = as.matrix(X), label = Y, params = param,
               nrounds = param$nrounds,
               base_score = 0)
  # no separate test set is supplied, so the log holds the training RMSE
  cat('Evaluation log showing testing error:\n')
  print(m$evaluation_log)
  pred <- predict(m, as.matrix(X), ntreelimit = param$nrounds)
  cat('Predicted values of Y: \n')
  print(pred)
  # per-feature contributions (SHAP values) plus the BIAS column
  pred2 <- predict(m, as.matrix(X), predcontrib = TRUE)
  cat("SHAP value for X: \n")
  print(pred2)
  p <- xgb.plot.tree(model = m)
  p
}
simple.xgb.output(param_gbtree)</code></pre>
<pre><code>## [1] train-rmse:22.360680
## [2] train-rmse:0.000000
## [3] train-rmse:0.000000
## Evaluation log showing testing error:
## iter train_rmse
## 1: 1 22.36
## 2: 2 0.00
## 3: 3 0.00
## Predicted values of Y:
## [1] 1 5 20 50 100
## SHAP value for X:
## x1 x2 BIAS
## [1,] -29.33 -4.867 35.2
## [2,] -26.00 -4.200 35.2
## [3,] -24.93 9.733 35.2
## [4,] 16.50 -1.700 35.2
## [5,] 66.50 -1.700 35.2</code></pre>
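<p>The two-step perfect fit can be checked by hand. The sketch below uses base R only and assumes the usual squared-error leaf weight, <code>eta * sum(residuals) / (n + lambda)</code>, with the default <code>lambda = 1</code>. The leaf memberships are read off the tree plot for this run; <code>leaf_value</code>, <code>leaves0</code> and <code>leaves1</code> are illustrative names, not part of <code>xgboost</code>.</p>
<pre class="r"><code># Assumed leaf weight for squared-error loss (hessian = 1 per row)
leaf_value <- function(resid, eta, lambda = 1) {
  eta * sum(resid) / (length(resid) + lambda)
}

Y <- c(1, 5, 20, 50, 100)
pred <- numeric(length(Y))          # base_score = 0

# Tree 0: rows 1, 2 and 3 each sit in their own leaf; rows 4 and 5 share one
leaves0 <- list(1, 2, 3, c(4, 5))
resid <- Y - pred
for (idx in leaves0) pred[idx] <- pred[idx] + leaf_value(resid[idx], eta = 2)
pred0 <- pred                       # 1 5 20 100 100

# Tree 1: rows 1-3 share a leaf; rows 4 and 5 are separated
leaves1 <- list(c(1, 2, 3), 4, 5)
resid <- Y - pred
for (idx in leaves1) pred[idx] <- pred[idx] + leaf_value(resid[idx], eta = 2)
pred                                # 1 5 20 50 100
sqrt(mean((Y - pred0)^2))           # train RMSE after the first tree</code></pre>
<p>For a leaf holding a single row, <code>n + lambda = 2</code> cancels <code>eta = 2</code>, so the leaf absorbs the whole residual. Under these assumptions, that cancellation is why <code>eta = 2</code> happens to give an exact fit on this dataset.</p>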
<div id="htmlwidget-1" style="width:672px;height:480px;" class="grViz html-widget"></div>
<script type="application/json" data-for="htmlwidget-1">{"x":{"diagram":"digraph {\n\ngraph [layout = \"dot\",\n rankdir = \"LR\"]\n\nnode [color = \"DimGray\",\n style = \"filled\",\n fontname = \"Helvetica\"]\n\nedge [color = \"DimGray\",\n arrowsize = \"1.5\",\n arrowhead = \"vee\",\n fontname = \"Helvetica\"]\n\n \"1\" [label = \"Tree 2\nLeaf\nCover: 5\nValue: 0\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"2\" [label = \"Tree 1\nx1\nCover: 5\nGain: 416.666656\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"3\" [label = \"Leaf\nCover: 3\nValue: 0\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"4\" [label = \"x1\nCover: 2\nGain: 416.666687\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"5\" [label = \"Leaf\nCover: 1\nValue: -50\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"6\" [label = \"Leaf\nCover: 1\nValue: 0\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"7\" [label = \"Tree 0\nx1\nCover: 5\nGain: 2506.3335\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"8\" [label = \"x2\nCover: 3\nGain: 43\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"9\" [label = \"Leaf\nCover: 2\nValue: 100\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"10\" [label = \"x1\nCover: 2\nGain: 1\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"11\" [label = \"Leaf\nCover: 1\nValue: 20\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"12\" [label = \"Leaf\nCover: 1\nValue: 1\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"13\" [label = \"Leaf\nCover: 1\nValue: 5\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n\"2\"->\"3\" [label = \"< 0.835661292\", style = \"bold\"] \n\"4\"->\"5\" [label = \"< 0.91174233\", style = \"bold\"] \n\"7\"->\"8\" [label = \"< 
0.835661292\", style = \"bold\"] \n\"8\"->\"10\" [label = \"< 0.710262299\", style = \"bold\"] \n\"10\"->\"12\" [label = \"< 0.537941337\", style = \"bold\"] \n\"2\"->\"4\" [style = \"bold\", style = \"solid\"] \n\"4\"->\"6\" [style = \"solid\", style = \"solid\"] \n\"7\"->\"9\" [style = \"solid\", style = \"solid\"] \n\"8\"->\"11\" [style = \"solid\", style = \"solid\"] \n\"10\"->\"13\" [style = \"solid\", style = \"solid\"] \n}","config":{"engine":"dot","options":null}},"evals":[],"jsHooks":[]}</script>
<ul>
<li>If <code>eta</code> = 1<br />
The model no longer reaches a perfect fit within three rounds. The trees grow in a more conservative manner, and the train RMSE halves at each iteration (22.83, 11.42, 5.71).</li>
</ul>
<pre class="r"><code># same gbtree settings as before, except eta = 1
param_gbtree <- list(objective = 'reg:linear', nrounds = 3,
eta = 1,
max_depth = 10,
min_child_weight = 0,
booster = 'gbtree'
)
simple.xgb.output(param_gbtree)</code></pre>
<pre><code>## [1] train-rmse:22.831995
## [2] train-rmse:11.415998
## [3] train-rmse:5.707999
## Evaluation log showing testing error:
## iter train_rmse
## 1: 1 22.832
## 2: 2 11.416
## 3: 3 5.708
## Predicted values of Y:
## [1] 0.875 4.375 17.500 50.000 87.500
## SHAP value for X:
## x1 x2 BIAS
## [1,] -27.18 -3.999 32.05
## [2,] -24.20 -3.478 32.05
## [3,] -24.11 9.564 32.05
## [4,] 20.41 -2.463 32.05
## [5,] 56.98 -1.525 32.05</code></pre>
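<p>The halving of the train RMSE can be traced with a few lines of base R. This is a sketch under the same assumption about the leaf weight, <code>eta * sum(residuals) / (n + lambda)</code> with the default <code>lambda = 1</code>; the leaf memberships in <code>partitions</code> are read off the tree plot for this run, and all names are illustrative.</p>
<pre class="r"><code># Assumed leaf weight for squared-error loss
leaf_value <- function(resid, eta, lambda = 1) {
  eta * sum(resid) / (length(resid) + lambda)
}

Y <- c(1, 5, 20, 50, 100)
partitions <- list(
  list(1, 2, 3, c(4, 5)),  # tree 0: rows 4 and 5 share a leaf
  list(1, 2, 3, 4, 5),     # tree 1: every row in its own leaf
  list(1, 2, 3, 4, 5)      # tree 2: same partition as tree 1
)

pred <- numeric(length(Y))  # base_score = 0
rmse <- numeric(0)
for (leaves in partitions) {
  resid <- Y - pred
  for (idx in leaves) pred[idx] <- pred[idx] + leaf_value(resid[idx], eta = 1)
  rmse <- c(rmse, sqrt(mean((Y - pred)^2)))
}
pred  # compare with the predicted values printed by the model
rmse  # compare with the evaluation log</code></pre>
<p>With <code>eta = 1</code>, a single-row leaf removes only half of its residual (<code>1 / (1 + lambda)</code>), which would explain why the error is cut in half at each round instead of vanishing.</p>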
<div id="htmlwidget-2" style="width:672px;height:480px;" class="grViz html-widget"></div>
<script type="application/json" data-for="htmlwidget-2">{"x":{"diagram":"digraph {\n\ngraph [layout = \"dot\",\n rankdir = \"LR\"]\n\nnode [color = \"DimGray\",\n style = \"filled\",\n fontname = \"Helvetica\"]\n\nedge [color = \"DimGray\",\n arrowsize = \"1.5\",\n arrowhead = \"vee\",\n fontname = \"Helvetica\"]\n\n \"1\" [label = \"Tree 2\nx1\nCover: 5\nGain: 155.575012\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"2\" [label = \"x2\nCover: 4\nGain: 4.61250019\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"3\" [label = \"Leaf\nCover: 1\nValue: 12.5\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"4\" [label = \"x1\nCover: 3\nGain: 0.1875\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"5\" [label = \"Leaf\nCover: 1\nValue: 2.5\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"6\" [label = \"x1\nCover: 2\nGain: 0.0625\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"7\" [label = \"Leaf\nCover: 1\nValue: 0\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"8\" [label = \"Leaf\nCover: 1\nValue: 0.125\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"9\" [label = \"Leaf\nCover: 1\nValue: 0.625\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"10\" [label = \"Tree 1\nx1\nCover: 5\nGain: 622.300049\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"11\" [label = \"x2\nCover: 4\nGain: 18.4500008\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"12\" [label = \"Leaf\nCover: 1\nValue: 25\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"13\" [label = \"x1\nCover: 3\nGain: 0.75\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"14\" [label = \"Leaf\nCover: 1\nValue: 5\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"15\" [label 
= \"x1\nCover: 2\nGain: 0.25\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"16\" [label = \"Leaf\nCover: 1\nValue: 0\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"17\" [label = \"Leaf\nCover: 1\nValue: 0.25\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"18\" [label = \"Leaf\nCover: 1\nValue: 1.25\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"19\" [label = \"Tree 0\nx1\nCover: 5\nGain: 2506.3335\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"20\" [label = \"x2\nCover: 3\nGain: 43\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"21\" [label = \"Leaf\nCover: 2\nValue: 50\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"22\" [label = \"x1\nCover: 2\nGain: 1\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"23\" [label = \"Leaf\nCover: 1\nValue: 10\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"24\" [label = \"Leaf\nCover: 1\nValue: 0.5\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"25\" [label = \"Leaf\nCover: 1\nValue: 2.5\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n\"1\"->\"2\" [label = \"< 0.91174233\", style = \"bold\"] \n\"2\"->\"4\" [label = \"< 0.721927047\", style = \"bold\"] \n\"4\"->\"6\" [label = \"< 0.835661292\", style = \"bold\"] \n\"6\"->\"8\" [label = \"< 0.537941337\", style = \"bold\"] \n\"10\"->\"11\" [label = \"< 0.91174233\", style = \"bold\"] \n\"11\"->\"13\" [label = \"< 0.721927047\", style = \"bold\"] \n\"13\"->\"15\" [label = \"< 0.835661292\", style = \"bold\"] \n\"15\"->\"17\" [label = \"< 0.537941337\", style = \"bold\"] \n\"19\"->\"20\" [label = \"< 0.835661292\", style = \"bold\"] \n\"20\"->\"22\" [label = \"< 0.710262299\", style = \"bold\"] \n\"22\"->\"24\" [label = \"< 0.537941337\", style = \"bold\"] \n\"1\"->\"3\" [style = \"bold\", style = 
\"solid\"] \n\"2\"->\"5\" [style = \"solid\", style = \"solid\"] \n\"4\"->\"7\" [style = \"solid\", style = \"solid\"] \n\"6\"->\"9\" [style = \"solid\", style = \"solid\"] \n\"10\"->\"12\" [style = \"solid\", style = \"solid\"] \n\"11\"->\"14\" [style = \"solid\", style = \"solid\"] \n\"13\"->\"16\" [style = \"solid\", style = \"solid\"] \n\"15\"->\"18\" [style = \"solid\", style = \"solid\"] \n\"19\"->\"21\" [style = \"solid\", style = \"solid\"] \n\"20\"->\"23\" [style = \"solid\", style = \"solid\"] \n\"22\"->\"25\" [style = \"solid\", style = \"solid\"] \n}","config":{"engine":"dot","options":null}},"evals":[],"jsHooks":[]}</script>
</div>
<div id="dart-dropout---mart" class="section level1">
<h1>DART: Dropout - MART</h1>
<p>DART (<a href="http://proceedings.mlr.press/v38/korlakaivinayak15.pdf">paper in JMLR</a>) adapts dropout from neural networks to boosted regression trees, or MART: Multiple Additive Regression Trees. DART aims to further prevent over-specialization. In <code>xgboost</code>, use <code>booster = 'dart'</code> and tune the related hyperparameters (<a href="https://xgboost.readthedocs.io/en/latest/parameter.html#additional-parameters-for-dart-booster-booster-dart">R documentation</a>).</p>
<div id="skip_drop" class="section level2">
<h2><code>skip_drop</code></h2>
<ul>
<li><code>skip_drop</code> (default = 0, range [0, 1]) is the probability of skipping dropout in a boosting iteration. It takes priority over the other DART parameters: if <code>skip_drop</code> = 1, dropout is always skipped and <code>dart</code> behaves the same as <code>gbtree</code>. The setting below gives the same result as the <code>gbtree</code> run above (results omitted):</li>
</ul>
<pre class="r"><code>param_gbtree <- list(objective = 'reg:linear', nrounds = 3,
eta = 2,
max_depth = 10,
booster = 'dart',
skip_drop = 1, # = 1 means always skip, = gbtree
rate_drop = 1, # doesn't matter since drop is always skipped
one_drop = 1
)</code></pre>
</div>
<div id="rate_drop" class="section level2">
<h2><code>rate_drop</code></h2>
<ul>
<li>If <code>skip_drop</code><span class="math inline">\(\ne1\)</span>, <code>rate_drop</code> (default = 0, range [0, 1]) sets the fraction of existing trees that are dropped before the model update in an iteration where dropout is applied.</li>
<li>The <a href="http://proceedings.mlr.press/v38/korlakaivinayak15.pdf">DART paper</a> describes dropout as placing DART between <code>gbtree</code> and random forest: “If no tree is dropped, DART is the same as MART (<code>gbtree</code>); if all the trees are dropped, DART is no different than random forest.”</li>
<li>If <code>rate_drop</code> = 1, all existing trees are dropped in every iteration, and a random forest of trees is built. On our very simple dataset, the ‘random forest’ just repeats the same tree <code>nrounds</code> times:</li>
</ul>
<pre class="r"><code>param_dart1 <- list(objective = 'reg:linear', nrounds = 3,
eta = 2,
max_depth = 10,
booster = 'dart',
skip_drop = 0,
rate_drop = 1, # drop every existing tree in each round
one_drop = 1
)
simple.xgb.output(param_dart1)</code></pre>
<pre><code>## [1] train-rmse:22.360680
## [2] train-rmse:16.948286
## [3] train-rmse:19.388212
## Evaluation log showing testing error:
## iter train_rmse
## 1: 1 22.36
## 2: 2 16.95
## 3: 3 19.39
## Predicted values of Y:
## [1] 0.5833 2.9167 11.6667 58.3333 58.3333
## SHAP value for X:
## x1 x2 BIAS
## [1,] -22.94 -2.8389 26.37
## [2,] -21.00 -2.4500 26.37
## [3,] -20.38 5.6778 26.37
## [4,] 32.96 -0.9917 26.37
## [5,] 32.96 -0.9917 26.37</code></pre>
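<p>The predictions and the non-monotone RMSE above can be reproduced from the tree weights alone. The sketch below assumes the default <code>normalize_type = 'tree'</code> rule described in the XGBoost DART documentation: when <code>k</code> trees are dropped, the new tree gets weight <code>1 / (k + eta)</code> and each dropped tree is rescaled by <code>k / (k + eta)</code>. The vector <code>tree_out</code> holds the per-row output of the repeated tree as read from the tree plot for this run; the variable names are illustrative.</p>
<pre class="r"><code>eta <- 2
Y <- c(1, 5, 20, 50, 100)
tree_out <- c(1, 5, 20, 100, 100)  # leaf value reached by each row in every tree

w <- 1                             # round 1: no earlier tree to drop
rmse <- sqrt(mean((Y - sum(w) * tree_out)^2))
for (round in 2:3) {
  k <- length(w)                   # rate_drop = 1 drops all k existing trees
  # rescale the dropped trees, then append the weight of the new tree
  w <- c(w * k / (k + eta), 1 / (k + eta))
  rmse <- c(rmse, sqrt(mean((Y - sum(w) * tree_out)^2)))
}
w                    # 1/6, 1/6, 1/4
sum(w) * tree_out    # compare with the predicted values printed by the model
rmse                 # compare with the evaluation log</code></pre>
<p>Because every earlier tree is dropped, each new tree is fitted to the raw targets and comes out identical to the first one; only the weights change, and their sum (1, then 2/3, then 7/12) determines the predictions.</p>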
<div id="htmlwidget-3" style="width:672px;height:480px;" class="grViz html-widget"></div>
<script type="application/json" data-for="htmlwidget-3">{"x":{"diagram":"digraph {\n\ngraph [layout = \"dot\",\n rankdir = \"LR\"]\n\nnode [color = \"DimGray\",\n style = \"filled\",\n fontname = \"Helvetica\"]\n\nedge [color = \"DimGray\",\n arrowsize = \"1.5\",\n arrowhead = \"vee\",\n fontname = \"Helvetica\"]\n\n \"1\" [label = \"Tree 2\nx1\nCover: 5\nGain: 2506.3335\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"2\" [label = \"x2\nCover: 3\nGain: 43\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"3\" [label = \"Leaf\nCover: 2\nValue: 100\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"4\" [label = \"x1\nCover: 2\nGain: 1\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"5\" [label = \"Leaf\nCover: 1\nValue: 20\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"6\" [label = \"Leaf\nCover: 1\nValue: 1\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"7\" [label = \"Leaf\nCover: 1\nValue: 5\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"8\" [label = \"Tree 1\nx1\nCover: 5\nGain: 2506.3335\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"9\" [label = \"x2\nCover: 3\nGain: 43\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"10\" [label = \"Leaf\nCover: 2\nValue: 100\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"11\" [label = \"x1\nCover: 2\nGain: 1\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"12\" [label = \"Leaf\nCover: 1\nValue: 20\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"13\" [label = \"Leaf\nCover: 1\nValue: 1\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"14\" [label = \"Leaf\nCover: 1\nValue: 5\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"15\" [label = \"Tree 0\nx1\nCover: 5\nGain: 
2506.3335\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"16\" [label = \"x2\nCover: 3\nGain: 43\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"17\" [label = \"Leaf\nCover: 2\nValue: 100\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"18\" [label = \"x1\nCover: 2\nGain: 1\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"19\" [label = \"Leaf\nCover: 1\nValue: 20\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"20\" [label = \"Leaf\nCover: 1\nValue: 1\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"21\" [label = \"Leaf\nCover: 1\nValue: 5\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n\"1\"->\"2\" [label = \"< 0.835661292\", style = \"bold\"] \n\"2\"->\"4\" [label = \"< 0.710262299\", style = \"bold\"] \n\"4\"->\"6\" [label = \"< 0.537941337\", style = \"bold\"] \n\"8\"->\"9\" [label = \"< 0.835661292\", style = \"bold\"] \n\"9\"->\"11\" [label = \"< 0.710262299\", style = \"bold\"] \n\"11\"->\"13\" [label = \"< 0.537941337\", style = \"bold\"] \n\"15\"->\"16\" [label = \"< 0.835661292\", style = \"bold\"] \n\"16\"->\"18\" [label = \"< 0.710262299\", style = \"bold\"] \n\"18\"->\"20\" [label = \"< 0.537941337\", style = \"bold\"] \n\"1\"->\"3\" [style = \"bold\", style = \"solid\"] \n\"2\"->\"5\" [style = \"solid\", style = \"solid\"] \n\"4\"->\"7\" [style = \"solid\", style = \"solid\"] \n\"8\"->\"10\" [style = \"solid\", style = \"solid\"] \n\"9\"->\"12\" [style = \"solid\", style = \"solid\"] \n\"11\"->\"14\" [style = \"solid\", style = \"solid\"] \n\"15\"->\"17\" [style = \"solid\", style = \"solid\"] \n\"16\"->\"19\" [style = \"solid\", style = \"solid\"] \n\"18\"->\"21\" [style = \"solid\", style = \"solid\"] \n}","config":{"engine":"dot","options":null}},"evals":[],"jsHooks":[]}</script>
</div>
<div id="one_drop" class="section level2">
<h2><code>one_drop</code></h2>
<ul>
<li>If <code>one_drop</code> = 1, at least one tree is always dropped whenever dropout is applied. With <code>rate_drop</code> = 0 but <code>one_drop</code> = 1, dropout is still active, and the trees are built in a more conservative manner. Because the first tree is dropped when the second tree is fitted, the second tree is identical to the first.</li>
</ul>
<pre class="r"><code>param_dart2 <- list(objective = 'reg:linear', nrounds = 5,
eta = 2,
max_depth = 10,
booster = 'dart',
skip_drop = 0,
rate_drop = 0, # no random dropping; one_drop below still forces one tree to be dropped
one_drop = 1
)
simple.xgb.output(param_dart2)</code></pre>
<pre><code>## [1] train-rmse:22.360680
## [2] train-rmse:16.948286
## [3] train-rmse:15.066435
## [4] train-rmse:15.237643
## [5] train-rmse:14.225216
## Evaluation log showing testing error:
## iter train_rmse
## 1: 1 22.36
## 2: 2 16.95
## 3: 3 15.07
## 4: 4 15.24
## 5: 5 14.23
## Predicted values of Y:
## [1] 0.7037 5.8848 14.7737 53.2510 68.8066
## SHAP value for X:
## x1 x2 BIAS
## [1,] -24.09 -3.893 28.68
## [2,] -19.28 -3.522 28.68
## [3,] -18.88 4.974 28.68
## [4,] 22.36 2.211 28.68
## [5,] 41.70 -1.578 28.68</code></pre>
<div id="htmlwidget-4" style="width:768px;height:960px;" class="grViz html-widget"></div>
<script type="application/json" data-for="htmlwidget-4">{"x":{"diagram":"digraph {\n\ngraph [layout = \"dot\",\n rankdir = \"LR\"]\n\nnode [color = \"DimGray\",\n style = \"filled\",\n fontname = \"Helvetica\"]\n\nedge [color = \"DimGray\",\n arrowsize = \"1.5\",\n arrowhead = \"vee\",\n fontname = \"Helvetica\"]\n\n \"1\" [label = \"Tree 4\nx1\nCover: 5\nGain: 679.732544\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"2\" [label = \"x1\nCover: 3\nGain: 8.1790123\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"3\" [label = \"Leaf\nCover: 2\nValue: 51.1111107\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"4\" [label = \"Leaf\nCover: 1\nValue: 0.555555582\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"5\" [label = \"Leaf\nCover: 2\nValue: 7.77777815\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"6\" [label = \"Tree 3\nx1\nCover: 5\nGain: 1150.75732\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"7\" [label = \"x1\nCover: 3\nGain: 12.6831322\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"8\" [label = \"Leaf\nCover: 2\nValue: 65.9259262\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"9\" [label = \"Leaf\nCover: 1\nValue: 0.666666687\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"10\" [label = \"Leaf\nCover: 2\nValue: 9.62963009\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"11\" [label = \"Tree 2\nx1\nCover: 5\nGain: 764.459351\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"12\" [label = \"x2\nCover: 4\nGain: 74.133316\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"13\" [label = \"Leaf\nCover: 1\nValue: 66.6666641\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"14\" [label = \"x1\nCover: 2\nGain: 0.444444656\", shape = 
\"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"15\" [label = \"Leaf\nCover: 2\nValue: 19.9999981\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"16\" [label = \"Leaf\nCover: 1\nValue: 0.666666627\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"17\" [label = \"Leaf\nCover: 1\nValue: 3.33333325\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"18\" [label = \"Tree 1\nx1\nCover: 5\nGain: 2506.3335\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"19\" [label = \"x2\nCover: 3\nGain: 43\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"20\" [label = \"Leaf\nCover: 2\nValue: 100\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"21\" [label = \"x1\nCover: 2\nGain: 1\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"22\" [label = \"Leaf\nCover: 1\nValue: 20\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"23\" [label = \"Leaf\nCover: 1\nValue: 1\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"24\" [label = \"Leaf\nCover: 1\nValue: 5\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"25\" [label = \"Tree 0\nx1\nCover: 5\nGain: 2506.3335\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"26\" [label = \"x2\nCover: 3\nGain: 43\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"27\" [label = \"Leaf\nCover: 2\nValue: 100\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"28\" [label = \"x1\nCover: 2\nGain: 1\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"29\" [label = \"Leaf\nCover: 1\nValue: 20\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"30\" [label = \"Leaf\nCover: 1\nValue: 1\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"31\" [label = \"Leaf\nCover: 1\nValue: 5\", 
shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n\"1\"->\"2\" [label = \"< 0.835661292\", style = \"bold\"] \n\"2\"->\"4\" [label = \"< 0.348277211\", style = \"bold\"] \n\"6\"->\"7\" [label = \"< 0.835661292\", style = \"bold\"] \n\"7\"->\"9\" [label = \"< 0.348277211\", style = \"bold\"] \n\"11\"->\"12\" [label = \"< 0.91174233\", style = \"bold\"] \n\"12\"->\"14\" [label = \"< 0.539770246\", style = \"bold\"] \n\"14\"->\"16\" [label = \"< 0.537941337\", style = \"bold\"] \n\"18\"->\"19\" [label = \"< 0.835661292\", style = \"bold\"] \n\"19\"->\"21\" [label = \"< 0.710262299\", style = \"bold\"] \n\"21\"->\"23\" [label = \"< 0.537941337\", style = \"bold\"] \n\"25\"->\"26\" [label = \"< 0.835661292\", style = \"bold\"] \n\"26\"->\"28\" [label = \"< 0.710262299\", style = \"bold\"] \n\"28\"->\"30\" [label = \"< 0.537941337\", style = \"bold\"] \n\"1\"->\"3\" [style = \"bold\", style = \"solid\"] \n\"2\"->\"5\" [style = \"solid\", style = \"solid\"] \n\"6\"->\"8\" [style = \"solid\", style = \"solid\"] \n\"7\"->\"10\" [style = \"solid\", style = \"solid\"] \n\"11\"->\"13\" [style = \"solid\", style = \"solid\"] \n\"12\"->\"15\" [style = \"solid\", style = \"solid\"] \n\"14\"->\"17\" [style = \"solid\", style = \"solid\"] \n\"18\"->\"20\" [style = \"solid\", style = \"solid\"] \n\"19\"->\"22\" [style = \"solid\", style = \"solid\"] \n\"21\"->\"24\" [style = \"solid\", style = \"solid\"] \n\"25\"->\"27\" [style = \"solid\", style = \"solid\"] \n\"26\"->\"29\" [style = \"solid\", style = \"solid\"] \n\"28\"->\"31\" [style = \"solid\", style = \"solid\"] \n}","config":{"engine":"dot","options":null}},"evals":[],"jsHooks":[]}</script>
<ul>
<li>A similar conservative effect appears if I set <code>rate_drop</code> to a non-zero value and turn <code>one_drop</code> off:</li>
</ul>
<pre class="r"><code>param_dart3 <- list(objective = 'reg:linear', nrounds = 5,
eta = 2,
max_depth = 10,
booster = 'dart',
skip_drop = 0,
rate_drop = 0.5, # dropout rate: fraction of existing trees to drop
one_drop = 0
)
simple.xgb.output(param_dart3)</code></pre>
<pre><code>## [1] train-rmse:22.360680
## [2] train-rmse:0.000000
## [3] train-rmse:7.453556
## [4] train-rmse:9.316948
## [5] train-rmse:9.823564
## Evaluation log showing testing error:
## iter train_rmse
## 1: 1 22.361
## 2: 2 0.000
## 3: 3 7.454
## 4: 4 9.317
## 5: 5 9.824
## Predicted values of Y:
## [1] 0.8333 4.1667 16.6667 63.8889 83.3333
## SHAP value for X:
## x1 x2 BIAS
## [1,] -28.89 -4.056 33.78
## [2,] -26.11 -3.500 33.78
## [3,] -25.22 8.111 33.78
## [4,] 31.53 -1.417 33.78
## [5,] 50.97 -1.417 33.78</code></pre>
<div id="htmlwidget-5" style="width:768px;height:960px;" class="grViz html-widget"></div>
<script type="application/json" data-for="htmlwidget-5">{"x":{"diagram":"digraph {\n\ngraph [layout = \"dot\",\n rankdir = \"LR\"]\n\nnode [color = \"DimGray\",\n style = \"filled\",\n fontname = \"Helvetica\"]\n\nedge [color = \"DimGray\",\n arrowsize = \"1.5\",\n arrowhead = \"vee\",\n fontname = \"Helvetica\"]\n\n \"1\" [label = \"Tree 4\nx1\nCover: 5\nGain: 2506.3335\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"2\" [label = \"x2\nCover: 3\nGain: 43\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"3\" [label = \"Leaf\nCover: 2\nValue: 100\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"4\" [label = \"x1\nCover: 2\nGain: 1\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"5\" [label = \"Leaf\nCover: 1\nValue: 20\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"6\" [label = \"Leaf\nCover: 1\nValue: 1\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"7\" [label = \"Leaf\nCover: 1\nValue: 5\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"8\" [label = \"Tree 3\nx1\nCover: 5\nGain: 416.666656\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"9\" [label = \"Leaf\nCover: 3\nValue: 0\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"10\" [label = \"x1\nCover: 2\nGain: 416.666687\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"11\" [label = \"Leaf\nCover: 1\nValue: -50\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"12\" [label = \"Leaf\nCover: 1\nValue: 0\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"13\" [label = \"Tree 2\nx1\nCover: 5\nGain: 416.666656\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"14\" [label = \"Leaf\nCover: 3\nValue: 0\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"15\" [label = \"x1\nCover: 
2\nGain: 416.666687\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"16\" [label = \"Leaf\nCover: 1\nValue: -50\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"17\" [label = \"Leaf\nCover: 1\nValue: 0\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"18\" [label = \"Tree 1\nx1\nCover: 5\nGain: 416.666656\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"19\" [label = \"Leaf\nCover: 3\nValue: 0\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"20\" [label = \"x1\nCover: 2\nGain: 416.666687\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"21\" [label = \"Leaf\nCover: 1\nValue: -50\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"22\" [label = \"Leaf\nCover: 1\nValue: 0\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"23\" [label = \"Tree 0\nx1\nCover: 5\nGain: 2506.3335\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"24\" [label = \"x2\nCover: 3\nGain: 43\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"25\" [label = \"Leaf\nCover: 2\nValue: 100\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"26\" [label = \"x1\nCover: 2\nGain: 1\", shape = \"rectangle\", fontcolor = \"black\", fillcolor = \"Beige\"] \n \"27\" [label = \"Leaf\nCover: 1\nValue: 20\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"28\" [label = \"Leaf\nCover: 1\nValue: 1\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n \"29\" [label = \"Leaf\nCover: 1\nValue: 5\", shape = \"oval\", fontcolor = \"black\", fillcolor = \"Khaki\"] \n\"1\"->\"2\" [label = \"< 0.835661292\", style = \"bold\"] \n\"2\"->\"4\" [label = \"< 0.710262299\", style = \"bold\"] \n\"4\"->\"6\" [label = \"< 0.537941337\", style = \"bold\"] \n\"8\"->\"9\" [label = \"< 0.835661292\", style = \"bold\"] \n\"10\"->\"11\" 
[label = \"< 0.91174233\", style = \"bold\"] \n\"13\"->\"14\" [label = \"< 0.835661292\", style = \"bold\"] \n\"15\"->\"16\" [label = \"< 0.91174233\", style = \"bold\"] \n\"18\"->\"19\" [label = \"< 0.835661292\", style = \"bold\"] \n\"20\"->\"21\" [label = \"< 0.91174233\", style = \"bold\"] \n\"23\"->\"24\" [label = \"< 0.835661292\", style = \"bold\"] \n\"24\"->\"26\" [label = \"< 0.710262299\", style = \"bold\"] \n\"26\"->\"28\" [label = \"< 0.537941337\", style = \"bold\"] \n\"1\"->\"3\" [style = \"bold\", style = \"solid\"] \n\"2\"->\"5\" [style = \"solid\", style = \"solid\"] \n\"4\"->\"7\" [style = \"solid\", style = \"solid\"] \n\"8\"->\"10\" [style = \"solid\", style = \"solid\"] \n\"10\"->\"12\" [style = \"solid\", style = \"solid\"] \n\"13\"->\"15\" [style = \"solid\", style = \"solid\"] \n\"15\"->\"17\" [style = \"solid\", style = \"solid\"] \n\"18\"->\"20\" [style = \"solid\", style = \"solid\"] \n\"20\"->\"22\" [style = \"solid\", style = \"solid\"] \n\"23\"->\"25\" [style = \"solid\", style = \"solid\"] \n\"24\"->\"27\" [style = \"solid\", style = \"solid\"] \n\"26\"->\"29\" [style = \"solid\", style = \"solid\"] \n}","config":{"engine":"dot","options":null}},"evals":[],"jsHooks":[]}</script>
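<p>Letting <code>one_drop</code> = 1 with <code>rate_drop</code> = 0 also gives a more conservative fit than plain <code>gbtree</code>, while its train RMSE after three rounds (15.07) is smaller than with <code>rate_drop</code> = 1 (19.39).</p>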
</div>
</div>